Knowledge distillation for vision guide #25619
Conversation
@sayakpaul I changed the setup and didn't observe a lot of difference, but I felt like it would still be cool to show how to distill a model. WDYT?
cc @rafaelpadilla for reference
Fantastic to see knowledge distillation being discussed—such an exciting topic! 🚀
Just shared a few comments and suggestions that might enhance readability. Most are related to writing style.
I appreciate the straightforward example you've provided. 👍
…cation.md Co-authored-by: Rafael Padilla <[email protected]>
…cation.md Co-authored-by: Rafael Padilla <[email protected]>
…cation.md Co-authored-by: Rafael Padilla <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
…cation.md Co-authored-by: Rafael Padilla <[email protected]>
@rafaelpadilla @NielsRogge can we merge this if this looks good?
Yes, it's OK with me.
…cation.md Co-authored-by: NielsRogge <[email protected]>
…cation.md Co-authored-by: NielsRogge <[email protected]>
…cation.md Co-authored-by: NielsRogge <[email protected]>
```python
dataset = load_dataset("beans")
```

We can use either of the processors given they return the same output. We will use the `map()` method of `dataset` to apply the preprocessing to every split of the dataset.
This sentence is actually not true; ResNet and MobileNet each have their own image processors.
They do return the same thing because the processors just do preprocessing at the same resolution. Check this out:
```python
from transformers import AutoFeatureExtractor
from PIL import Image
import requests
import numpy as np

teacher_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
student_extractor = AutoFeatureExtractor.from_pretrained("google/mobilenet_v2_1.4_224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
sample = Image.open(requests.get(url, stream=True).raw)

np.array_equal(teacher_extractor(sample), student_extractor(sample))
# True
```
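To illustrate the `map()` step being discussed, here is a minimal sketch of the preprocessing; the `process` function, the use of `AutoImageProcessor` (the newer name for the feature extractor API above), and the handling of the `beans` columns are illustrative assumptions, not necessarily what the guide uses.

```python
from datasets import load_dataset
from transformers import AutoImageProcessor

teacher_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
dataset = load_dataset("beans")

def process(examples):
    # Turn the PIL images into pixel_values; keep the labels unchanged
    processed = teacher_processor(images=examples["image"], return_tensors="pt")
    return {"pixel_values": processed["pixel_values"], "labels": examples["labels"]}

processed_datasets = dataset.map(process, batched=True)
```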
Thanks for writing this up! ❤️
…cation.md Co-authored-by: NielsRogge <[email protected]>
…cation.md Co-authored-by: NielsRogge <[email protected]>
…cation.md Co-authored-by: NielsRogge <[email protected]>
…cation.md Co-authored-by: NielsRogge <[email protected]>
…cation.md Co-authored-by: NielsRogge <[email protected]>
Great work on the guide! While reading it I had a few questions that I feel other folks may have, and it would be great to address them :)
```python
processed_datasets = dataset.map(process, batched=True)
```

Essentially, we want the student model (a randomly initialized MobileNet) to mimic the teacher model (a pre-trained ResNet). To achieve this, we first get the logits output by the teacher and the student. Then, we divide each of them by the parameter `temperature`, which controls the importance of each soft target. We will use the KL loss to compute the divergence between the student and teacher. A parameter called `lambda` weighs the importance of the distillation loss. In this example, we will use `temperature=5` and `lambda=0.5`.
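To make the loss described above concrete, here is a minimal sketch of how it could be implemented by subclassing `Trainer`. The class name, constructor arguments, and the assumption that the teacher is already on the same device as the student are illustrative, not necessarily the guide's exact implementation.

```python
import torch
import torch.nn.functional as F
from transformers import Trainer

class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, temperature=5.0, lambda_param=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assumes teacher_model is already on the same device as the student
        self.teacher = teacher_model
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, model, inputs, return_outputs=False):
        student_output = model(**inputs)
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        # Soften both sets of logits with the temperature, then compute KL divergence
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)
        distillation_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (
            self.temperature**2
        )

        # Mix the distillation loss with the regular cross-entropy loss on the hard labels
        # (assumes `labels` are present in `inputs`, so `student_output.loss` is set)
        loss = (1.0 - self.lambda_param) * student_output.loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss
```

The trainer would then be instantiated with the student passed as `model` and the teacher passed in separately via `teacher_model`.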
Would be cool to link KL loss to some page that gives a definition of what that is for people who are not familiar.
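For readers who aren't familiar with it, the Kullback-Leibler (KL) divergence between the teacher's softened distribution $p$ and the student's softened distribution $q$ is

$$
D_{\mathrm{KL}}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i},
$$

which is zero when the two distributions are identical and grows as the student's predictions drift away from the teacher's.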
Since you're customizing the Trainer, it would also be nice to link to this page: https://huggingface.co/docs/transformers/en/main_classes/trainer#trainer
The first sentence would be great to have somewhere in the introduction - how the distillation works. Something like: "To distill knowledge from one model to another, we take a pre-trained teacher model, and randomly initialize a student model. Next, we train the student model to minimize the difference between its outputs and the teacher's outputs, thus making it mimic the behavior. "
```python
trainer.evaluate(processed_datasets["test"])
```
Maybe also push the final model to the Hub with `trainer.push_to_hub()`?
I think the final model is pushed already when we set `push_to_hub` to `True` (I also have the save strategy enabled for every epoch, so it's triggered every epoch as well), no?
AFAIK `trainer.push_to_hub()` also creates a basic model card, e.g. with metrics, and some training results.
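For context, here is a minimal sketch of the training-argument setup being discussed; the output directory name and the specific values are illustrative assumptions.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="distilled-mobilenet",  # hypothetical output/repo name
    push_to_hub=True,                  # push checkpoints to the Hub during training
    save_strategy="epoch",             # save (and therefore push) at the end of every epoch
)
```

With this setup, checkpoints are uploaded as training runs, while a final `trainer.push_to_hub()` call additionally generates the basic model card with metrics mentioned above.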
…cation.md Co-authored-by: Maria Khalusova <[email protected]>
…cation.md Co-authored-by: Maria Khalusova <[email protected]>
…cation.md Co-authored-by: Maria Khalusova <[email protected]>
…cation.md Co-authored-by: Maria Khalusova <[email protected]>
Thank you for iterating on this! This revision looks fantastic :)
Looks good to me. :)
@LysandreJik can you give a review or ask for another reviewer if needed?
Thank you @merveenoyan!
Please resolve the merge conflicts and merge, @merveenoyan
* Knowledge distillation for vision guide
* Update knowledge_distillation_for_image_classification.md
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Rafael Padilla <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Rafael Padilla <[email protected]>
* Iterated on Rafael's comments
* Added to toctree
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Rafael Padilla <[email protected]>
* Addressed comments
* Update knowledge_distillation_for_image_classification.md
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Rafael Padilla <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: NielsRogge <[email protected]>
* Update knowledge_distillation_for_image_classification.md
* Update knowledge_distillation_for_image_classification.md
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Maria Khalusova <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Maria Khalusova <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Maria Khalusova <[email protected]>
* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: Maria Khalusova <[email protected]>
* Address comments
* Update knowledge_distillation_for_image_classification.md
* Explain KL Div

---------

Co-authored-by: Rafael Padilla <[email protected]>
Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Maria Khalusova <[email protected]>
This is a draft PR that I opened in the past on the KD guide for CV, but I accidentally removed my fork. I prioritized TGI docs, so this PR might stay stale for a while; I will ask for a review after I iterate over the comments left by @sayakpaul in my previous PR (mainly training MobileNet with random initial weights and not with pre-trained weights from transformers).